{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import collections\n",
    "from sklearn.cross_validation import train_test_split\n",
    "from tensor2tensor.utils import beam_search, rouge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# JSON is UTF-8 by specification; pin the encoding so loading does not depend on the locale.\n",
    "with open('dataset/ctexts.json', 'r', encoding = 'utf-8') as fopen:\n",
    "    ctexts = json.load(fopen)\n",
    "\n",
    "with open('dataset/headlines.json', 'r', encoding = 'utf-8') as fopen:\n",
    "    headlines = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.decomposition import TruncatedSVD\n",
    "\n",
    "# Number of terms kept per article; also the fixed width of the encoder input placeholder.\n",
    "max_len = 300\n",
    "\n",
    "def topic_modelling(string, n = 200):\n",
    "    '''Return the `n` highest-weighted terms of `string`, joined by single spaces.\n",
    "\n",
    "    A TF-IDF matrix is fitted on this one document only and reduced with a\n",
    "    one-component TruncatedSVD; terms are ordered by descending weight in that\n",
    "    component. Fewer than `n` terms come back when the document has fewer\n",
    "    distinct terms. Exceptions raised by scikit-learn propagate to the caller.\n",
    "    '''\n",
    "    vectorizer = TfidfVectorizer()\n",
    "    # Named `tfidf` rather than `tf` so the TensorFlow module alias is not shadowed.\n",
    "    tfidf = vectorizer.fit_transform([string])\n",
    "    # `get_feature_names` is absent from newer scikit-learn releases; use its successor when present.\n",
    "    if hasattr(vectorizer, 'get_feature_names_out'):\n",
    "        tf_features = vectorizer.get_feature_names_out()\n",
    "    else:\n",
    "        tf_features = vectorizer.get_feature_names()\n",
    "    compose = TruncatedSVD(1).fit(tfidf)\n",
    "    # argsort is ascending, so walk it backwards to take the top `n` indices.\n",
    "    top_indices = compose.components_[0].argsort()[: -n - 1 : -1]\n",
    "    return ' '.join([tf_features[i] for i in top_indices])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/sklearn/decomposition/truncated_svd.py:192: RuntimeWarning: invalid value encountered in true_divide\n",
      "  self.explained_variance_ratio_ = exp_var / full_var\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 18.8 s, sys: 24.1 s, total: 42.9 s\n",
      "Wall time: 10.9 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Build aligned (article terms, headline) pairs; articles that cannot be processed are skipped.\n",
    "h, c = [], []\n",
    "skipped = 0\n",
    "for ctext, headline in zip(ctexts, headlines):\n",
    "    try:\n",
    "        terms = topic_modelling(ctext, max_len)\n",
    "    except Exception:\n",
    "        # Still best effort, but unlike a bare `except` this lets KeyboardInterrupt through.\n",
    "        skipped += 1\n",
    "        continue\n",
    "    # Append both halves together so `c` and `h` can never get out of step.\n",
    "    c.append(terms)\n",
    "    h.append(headline)\n",
    "print('skipped %d of %d articles' % (skipped, len(ctexts)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(words, n_words):\n",
    "    '''Build an integer vocabulary from `words` and encode the token stream.\n",
    "\n",
    "    Ids 0-3 are reserved for PAD, GO, EOS and UNK; the `n_words` most frequent\n",
    "    tokens receive the following ids in order of decreasing frequency.\n",
    "\n",
    "    Returns a 4-tuple:\n",
    "        data: `words` mapped to ids, out-of-vocabulary tokens mapped to UNK.\n",
    "        count: the four reserved entries followed by (token, frequency) pairs;\n",
    "            the UNK entry carries the number of out-of-vocabulary tokens.\n",
    "        dictionary: token -> id.\n",
    "        reversed_dictionary: id -> token.\n",
    "    '''\n",
    "    count = [['PAD', 0], ['GO', 1], ['EOS', 2], ['UNK', 3]]\n",
    "    count.extend(collections.Counter(words).most_common(n_words))\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        # A corpus token spelled like a reserved symbol must not overwrite the\n",
    "        # reserved id (that would also make two tokens share one id).\n",
    "        if word not in dictionary:\n",
    "            dictionary[word] = len(dictionary)\n",
    "    unk_id = dictionary['UNK']\n",
    "    data = list()\n",
    "    unk_count = 0\n",
    "    for word in words:\n",
    "        index = dictionary.get(word)\n",
    "        if index is None:\n",
    "            # Unknown tokens become UNK; id 0 is PAD and must stay reserved for padding.\n",
    "            index = unk_id\n",
    "            unk_count += 1\n",
    "        data.append(index)\n",
    "    # Tally out-of-vocabulary tokens on the UNK entry, not on PAD.\n",
    "    count[unk_id][1] = unk_count\n",
    "    reversed_dictionary = {index: word for word, index in dictionary.items()}\n",
    "    return data, count, dictionary, reversed_dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 47609\n",
      "Most common words [('dot', 4394), ('the', 4379), ('comma', 4349), ('to', 4280), ('in', 4268), ('of', 4262)]\n",
      "Sample data [5, 7, 4, 6, 10, 9, 11, 8, 15, 1990] ['the', 'to', 'dot', 'comma', 'and', 'of', 'on', 'in', 'was', 'festival']\n",
      "filtered vocab size: 47613\n",
      "% of vocab used: 100.01%\n"
     ]
    }
   ],
   "source": [
    "# Flatten every article into one token stream and build the vocabulary over all distinct tokens.\n",
    "concat = ' '.join(c).split()\n",
    "vocabulary_size = len(set(concat))\n",
    "data, count, dictionary, rev_dictionary = build_dataset(concat, vocabulary_size)\n",
    "print('vocab from size: %d'%(vocabulary_size))\n",
    "print('Most common words', count[4:10])\n",
    "print('Sample data', data[:10], [rev_dictionary[i] for i in data[:10]])\n",
    "print('filtered vocab size:',len(dictionary))\n",
    "# Format the ratio directly: `round(x, 4) * 100` can print float artefacts such as 99.99000000000001.\n",
    "print('% of vocab used: {:.2%}'.format(len(dictionary) / vocabulary_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'daman and diu revokes mandatory rakshabandhan in offices order EOS'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Idempotent: re-running the cell must not append a second end-of-sequence marker.\n",
    "h = [line if line.endswith(' EOS') else line + ' EOS' for line in h]\n",
    "h[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ids of the four reserved symbols, looked up once from the vocabulary.\n",
    "GO, PAD, EOS, UNK = (dictionary[symbol] for symbol in ('GO', 'PAD', 'EOS', 'UNK'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def str_idx(corpus, dic, UNK=3):\n",
    "    '''Encode every string in `corpus` as a list of ids.\n",
    "\n",
    "    Each string is split on whitespace and every token is looked up in `dic`;\n",
    "    tokens missing from `dic` become `UNK`.\n",
    "    '''\n",
    "    return [[dic.get(token, UNK) for token in sentence.split()] for sentence in corpus]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the article terms (X) and the headlines (Y) with the shared vocabulary.\n",
    "X, Y = (str_idx(corpus, dictionary) for corpus in (c, h))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so the train/test partition is identical on every run.\n",
    "train_X, test_X, train_Y, test_Y = train_test_split(X, Y, test_size = 0.05, random_state = 42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def position_encoding(inputs):\n",
    "    '''Sinusoidal position table tiled over the batch axis of `inputs`.\n",
    "\n",
    "    The first half of the last axis holds sin(pos / 10000 ** (i / depth)) and the\n",
    "    second half the matching cosines, for even i below depth. The last dimension\n",
    "    of `inputs` must be statically known; the result matches the shape of `inputs`\n",
    "    when that dimension is even.\n",
    "    '''\n",
    "    depth = inputs.get_shape()[-1].value\n",
    "    frequencies = np.power(10000.0, np.arange(0, depth, 2, np.float32) / depth)\n",
    "    frequencies = np.reshape(frequencies, [1, -1])\n",
    "    seq_len = tf.shape(inputs)[1]\n",
    "    positions = tf.range(0.0, tf.to_float(seq_len), dtype=tf.float32)\n",
    "    angles = tf.reshape(positions, [-1, 1]) / frequencies\n",
    "    table = tf.concat([tf.sin(angles), tf.cos(angles)], 1)\n",
    "    return tf.tile(tf.expand_dims(table, 0), [tf.shape(inputs)[0], 1, 1])\n",
    "\n",
    "def layer_norm(inputs, epsilon=1e-8):\n",
    "    '''Normalise `inputs` over the last axis, then apply a learned scale and shift.\n",
    "\n",
    "    Fetches the variables 'gamma' and 'beta' through `tf.get_variable`, each sized\n",
    "    to the last dimension of `inputs`, so the enclosing variable scope decides\n",
    "    whether they are created or reused.\n",
    "    '''\n",
    "    feature_shape = inputs.get_shape()[-1:]\n",
    "    mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)\n",
    "    scale = tf.get_variable('gamma', feature_shape, tf.float32, tf.ones_initializer())\n",
    "    shift = tf.get_variable('beta', feature_shape, tf.float32, tf.zeros_initializer())\n",
    "    normalized = (inputs - mean) / tf.sqrt(variance + epsilon)\n",
    "    return scale * normalized + shift\n",
    "\n",
    "\n",
    "def cnn_block(x, dilation_rate, pad_sz, hidden_dim, kernel_size):\n",
    "    '''Layer-normalised, causal, dilated 1-D convolution followed by ReLU.\n",
    "\n",
    "    Args:\n",
    "        x: float tensor whose axis 1 is the time axis.\n",
    "        dilation_rate: dilation of the convolution.\n",
    "        pad_sz: number of zero steps placed in front of the sequence. With\n",
    "            pad_sz == (kernel_size - 1) * dilation_rate the output keeps the\n",
    "            input length.\n",
    "        hidden_dim: number of convolution filters.\n",
    "        kernel_size: width of the convolution kernel.\n",
    "\n",
    "    Returns:\n",
    "        Tensor with `hidden_dim` channels.\n",
    "    '''\n",
    "    x = layer_norm(x)\n",
    "    # Left-pad only, so every output step sees just the current and earlier steps.\n",
    "    # Padding both ends and trimming the tail gives the same values but costs extra\n",
    "    # compute and breaks for pad_sz == 0, because `x[:, :-0]` selects nothing.\n",
    "    x = tf.pad(x, [[0, 0], [pad_sz, 0], [0, 0]])\n",
    "    x = tf.layers.conv1d(inputs = x,\n",
    "                         filters = hidden_dim,\n",
    "                         kernel_size = kernel_size,\n",
    "                         dilation_rate = dilation_rate)\n",
    "    return tf.nn.relu(x)\n",
    "\n",
    "class Summarization:\n",
    "    '''Convolutional encoder-decoder with multi-head attention and a copy mechanism.\n",
    "\n",
    "    Constructing an instance adds the training graph to the default graph and exposes\n",
    "    the placeholders `X` and `Y`, the output distribution `training_logits`, the loss\n",
    "    `cost` (masked negative log-likelihood plus coverage penalty), the Adam update\n",
    "    `optimizer`, and the masked token `accuracy`.\n",
    "\n",
    "    Reads the notebook-level names `max_len` (encoder input width) and `GO`.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, size_layer, num_layers, embedded_size,\n",
    "                 dict_size, learning_rate,\n",
    "                 kernel_size = 2, n_attn_heads = 16):\n",
    "        '''Build placeholders, the forward pass, the loss, the optimiser and the metrics.\n",
    "\n",
    "        Args:\n",
    "            size_layer: channel width of the convolution and attention layers.\n",
    "            num_layers: number of encoder blocks and of decoder blocks.\n",
    "            embedded_size: width of the shared token embedding.\n",
    "            dict_size: vocabulary size.\n",
    "            learning_rate: Adam learning rate.\n",
    "            kernel_size: convolution kernel width.\n",
    "            n_attn_heads: attention heads per decoder block.\n",
    "        '''\n",
    "        self.X = tf.placeholder(tf.int32, [None, max_len])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "\n",
    "        # Lengths are the number of non-zero ids per row, i.e. id 0 is treated as padding.\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype = tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype = tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        self.batch_size = batch_size\n",
    "        # Teacher forcing: drop the last target column and prepend GO.\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "\n",
    "        self.embedding = tf.Variable(tf.random_uniform([dict_size, embedded_size], -1, 1))\n",
    "\n",
    "        self.num_layers = num_layers\n",
    "        self.kernel_size = kernel_size\n",
    "        self.size_layer = size_layer\n",
    "        self.n_attn_heads = n_attn_heads\n",
    "        self.dict_size = dict_size\n",
    "\n",
    "        self.training_logits, coverage_loss = self.forward(self.X, decoder_input)\n",
    "        maxlen = tf.reduce_max(self.Y_seq_len)\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, maxlen, dtype=tf.float32)\n",
    "\n",
    "        # For every (example, step) pick the probability assigned to the target id.\n",
    "        targets = tf.slice(self.Y, [0, 0], [-1, maxlen])\n",
    "        i1, i2 = tf.meshgrid(tf.range(batch_size),\n",
    "                             tf.range(maxlen), indexing='ij')\n",
    "        indices = tf.stack((i1, i2, targets), axis=2)\n",
    "        probs = tf.gather_nd(self.training_logits, indices)\n",
    "        # Guard the log against non-positive probabilities.\n",
    "        probs = tf.where(tf.less_equal(probs, 0), tf.ones_like(probs) * 1e-10, probs)\n",
    "        crossent = -tf.log(probs)\n",
    "        self.cost = tf.reduce_sum(crossent * masks) / tf.to_float(batch_size)\n",
    "        self.coverage_loss = tf.reduce_sum(coverage_loss / tf.to_float(batch_size))\n",
    "        self.cost = self.cost + self.coverage_loss\n",
    "\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)\n",
    "        y_t = tf.argmax(self.training_logits, axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        # Accuracy is measured on non-padding target positions only.\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n",
    "\n",
    "    def _calc_final_dist(self, x, gens, vocab_dists, attn_dists):\n",
    "        '''Blend the vocabulary distribution with attention mass copied onto source ids.\n",
    "\n",
    "        Args:\n",
    "            x: source token ids, [batch, attn_len].\n",
    "            gens: generation gate, multiplied into `vocab_dists`; (1 - gens) weights the copy part.\n",
    "            vocab_dists: distribution over the vocabulary, [batch, dec_t, dict_size].\n",
    "            attn_dists: attention over source positions, [batch, dec_t, attn_len].\n",
    "\n",
    "        Returns:\n",
    "            gens * vocab_dists plus (1 - gens) * attn_dists scattered to the vocabulary\n",
    "            positions of the attended source ids, [batch, dec_t, dict_size].\n",
    "        '''\n",
    "        with tf.variable_scope('final_distribution', reuse=tf.AUTO_REUSE):\n",
    "            vocab_dists = gens * vocab_dists\n",
    "            attn_dists = (1-gens) * attn_dists\n",
    "            batch_size = tf.shape(attn_dists)[0]\n",
    "            dec_t = tf.shape(attn_dists)[1]\n",
    "            attn_len = tf.shape(attn_dists)[2]\n",
    "\n",
    "            dec = tf.range(0, limit=dec_t) # [dec]\n",
    "            dec = tf.expand_dims(dec, axis=-1) # [dec, 1]\n",
    "            dec = tf.tile(dec, [1, attn_len]) # [dec, atten_len]\n",
    "            dec = tf.expand_dims(dec, axis=0) # [1, dec, atten_len]\n",
    "            dec = tf.tile(dec, [batch_size, 1, 1]) # [batch_size, dec, atten_len]\n",
    "\n",
    "            x = tf.expand_dims(x, axis=1) # [batch_size, 1, atten_len]\n",
    "            x = tf.tile(x, [1, dec_t, 1]) # [batch_size, dec, atten_len]\n",
    "            # Pair every attention weight with its (decoder step, source token id) target cell.\n",
    "            x = tf.stack([dec, x], axis=3)\n",
    "\n",
    "            # Per example, scatter the attention weights into a [dec_t, dict_size] table.\n",
    "            attn_dists_projected = tf.map_fn(\n",
    "                fn=lambda y: tf.scatter_nd(y[0], y[1], [dec_t, self.dict_size]),\n",
    "                elems=(x, attn_dists), dtype=tf.float32)\n",
    "\n",
    "            final_dists = attn_dists_projected + vocab_dists\n",
    "            return final_dists\n",
    "\n",
    "    def forward(self, x, y, reuse = False):\n",
    "        '''Run encoder and decoder; return the output distribution and the coverage term.\n",
    "\n",
    "        Args:\n",
    "            x: source ids, [batch, max_len].\n",
    "            y: decoder input ids, [batch, dec_t].\n",
    "            reuse: forwarded to the variable scopes; pass True to share the weights\n",
    "                created by an earlier call.\n",
    "\n",
    "        Returns:\n",
    "            (final_dists, coverage_loss): final_dists holds probabilities with shape\n",
    "            [batch, dec_t, dict_size]; coverage_loss is the element-wise coverage\n",
    "            penalty laid out as [batch, attn_len, dec_t].\n",
    "        '''\n",
    "        with tf.variable_scope('forward',reuse=reuse):\n",
    "            with tf.variable_scope('forward',reuse=reuse):\n",
    "                encoder_embedded = tf.nn.embedding_lookup(self.embedding, x)\n",
    "                decoder_embedded = tf.nn.embedding_lookup(self.embedding, y)\n",
    "\n",
    "                encoder_embedded += position_encoding(encoder_embedded)\n",
    "                decoder_embedded += position_encoding(decoder_embedded)\n",
    "\n",
    "                # Encoder: residual stack of convolutions with dilation 1, 2, 4, ...\n",
    "                for i in range(self.num_layers):\n",
    "                    dilation_rate = 2 ** i\n",
    "                    pad_sz = (self.kernel_size - 1) * dilation_rate\n",
    "                    with tf.variable_scope('block_%d'%i,reuse=reuse):\n",
    "                        encoder_embedded += cnn_block(encoder_embedded, dilation_rate,\n",
    "                                                      pad_sz, self.size_layer, self.kernel_size)\n",
    "\n",
    "                g = tf.identity(decoder_embedded)\n",
    "                dec = decoder_embedded\n",
    "                attn_dists = []\n",
    "                # Decoder: each block is a convolution followed by multi-head attention over the encoder.\n",
    "                for i in range(self.num_layers):\n",
    "                    dilation_rate = 2 ** i\n",
    "                    pad_sz = (self.kernel_size - 1) * dilation_rate\n",
    "                    with tf.variable_scope('decode_%d'%i,reuse=reuse):\n",
    "                        attn_res = h = cnn_block(dec, dilation_rate,\n",
    "                                                 pad_sz, self.size_layer, self.kernel_size)\n",
    "                        C = []\n",
    "                        for j in range(self.n_attn_heads):\n",
    "                            h_ = tf.layers.dense(h, self.size_layer//self.n_attn_heads)\n",
    "                            g_ = tf.layers.dense(g, self.size_layer//self.n_attn_heads)\n",
    "                            zu_ = tf.layers.dense(encoder_embedded, self.size_layer//self.n_attn_heads)\n",
    "                            ze_ = tf.layers.dense(encoder_embedded, self.size_layer//self.n_attn_heads)\n",
    "\n",
    "                            d = tf.layers.dense(h_, self.size_layer//self.n_attn_heads) + g_\n",
    "                            dz = tf.matmul(d, tf.transpose(zu_, [0, 2, 1]))\n",
    "                            a = tf.nn.softmax(dz)\n",
    "                            attn_dists.append(a)\n",
    "                            c_ = tf.matmul(a, ze_)\n",
    "                            C.append(c_)\n",
    "\n",
    "                        c = tf.concat(C, 2)\n",
    "                        h = tf.layers.dense(attn_res + c, self.size_layer)\n",
    "                        dec += h\n",
    "\n",
    "                # Output projection shares its weights with the input embedding.\n",
    "                weights = tf.transpose(self.embedding)\n",
    "                logits = tf.einsum('ntd,dk->ntk', dec, weights)\n",
    "\n",
    "                # Gate deciding, per decoder step, between generating and copying.\n",
    "                with tf.variable_scope('gen', reuse=tf.AUTO_REUSE):\n",
    "                    gens = tf.layers.dense(tf.concat([decoder_embedded, dec, attn_dists[-1]], axis=-1),\n",
    "                                           units=1, activation=tf.sigmoid, use_bias=False)\n",
    "\n",
    "                logits = tf.nn.softmax(logits)\n",
    "\n",
    "                # Coverage penalty: overlap between the attention at a decoder step and the\n",
    "                # attention accumulated over the earlier steps of the same example. The\n",
    "                # attention is [batch, dec_t, attn_len], so decoder time is moved to the last\n",
    "                # axis and the running sum is taken along it (not across the batch).\n",
    "                alignment_history = tf.transpose(attn_dists[-1], [0, 2, 1])\n",
    "                coverage_loss = tf.minimum(alignment_history,\n",
    "                                           tf.cumsum(alignment_history, axis=2, exclusive=True))\n",
    "\n",
    "                return self._calc_final_dist(x, gens, logits, attn_dists[-1]), coverage_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model size: `size_layer` is the convolution/attention width, `embedded_size` the embedding width.\n",
    "size_layer = embedded_size = 256\n",
    "num_layers = 4\n",
    "\n",
    "# Optimisation settings.\n",
    "learning_rate, batch_size, epoch = 1e-3, 16, 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def beam_search_decoding(length = 20, beam_width = 5):\n",
    "    '''Add beam-search decoding ops for the global `model` and return the decoded ids.\n",
    "\n",
    "    Args:\n",
    "        length: decode length handed to tensor2tensor beam search.\n",
    "        beam_width: number of beams kept per example.\n",
    "\n",
    "    Returns:\n",
    "        The ids tensor produced by `beam_search.beam_search`.\n",
    "    '''\n",
    "    initial_ids = tf.fill([model.batch_size], GO)\n",
    "\n",
    "    def symbols_to_logits(ids):\n",
    "        # Every example is repeated once per beam so the source lines up with `ids`.\n",
    "        x = tf.contrib.seq2seq.tile_batch(model.X, beam_width)\n",
    "        probs, _ = model.forward(x, ids, reuse = True)\n",
    "        step_probs = probs[:, tf.shape(ids)[1]-1, :]\n",
    "        # `forward` returns probabilities, while beam search applies a log-softmax to\n",
    "        # what it receives. Log-probabilities survive that unchanged; raw probabilities\n",
    "        # would be squashed into a nearly uniform distribution.\n",
    "        return tf.log(tf.maximum(step_probs, 1e-10))\n",
    "\n",
    "    final_ids, final_probs = beam_search.beam_search(\n",
    "        symbols_to_logits,\n",
    "        initial_ids,\n",
    "        beam_width,\n",
    "        length,\n",
    "        len(dictionary),\n",
    "        0.0,\n",
    "        eos_id = EOS)\n",
    "\n",
    "    return final_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"forward/forward/add_1:0\", shape=(?, ?, 256), dtype=float32) Tensor(\"forward/forward/decode_3/add_19:0\", shape=(?, ?, 256), dtype=float32) Tensor(\"forward/forward/decode_3/Reshape_31:0\", shape=(?, ?, 300), dtype=float32)\n",
      "Tensor(\"forward/forward/gen/dense/Sigmoid:0\", shape=(?, ?, 1), dtype=float32)\n",
      "Tensor(\"while/forward/forward/add_1:0\", shape=(?, ?, 256), dtype=float32) Tensor(\"while/forward/forward/decode_3/add_19:0\", shape=(?, ?, 256), dtype=float32) Tensor(\"while/forward/forward/decode_3/Reshape_31:0\", shape=(?, ?, 300), dtype=float32)\n",
      "Tensor(\"while/forward/forward/gen/dense/Sigmoid:0\", shape=(?, ?, 1), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# Start from an empty default graph, then build the model and the decoding ops in it.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Summarization(size_layer = size_layer,\n",
    "                      num_layers = num_layers,\n",
    "                      embedded_size = embedded_size,\n",
    "                      dict_size = len(dictionary),\n",
    "                      learning_rate = learning_rate)\n",
    "# The beam-search ops reuse the weights created by the constructor above.\n",
    "model.generate = beam_search_decoding()\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int, maxlen = None):\n",
    "    '''Pad (and, when `maxlen` is given, truncate) id sequences to one common length.\n",
    "\n",
    "    Args:\n",
    "        sentence_batch: list of lists of ids.\n",
    "        pad_int: id used for padding.\n",
    "        maxlen: target length; when omitted (or 0) the longest sequence in the\n",
    "            batch sets the length.\n",
    "\n",
    "    Returns:\n",
    "        (padded_seqs, seq_lens): the fixed-length sequences and the number of real\n",
    "        ids kept in each of them.\n",
    "    '''\n",
    "    if maxlen:\n",
    "        target_len = maxlen\n",
    "    else:\n",
    "        # `default` keeps an empty batch from raising inside `max`.\n",
    "        target_len = max((len(sentence) for sentence in sentence_batch), default = 0)\n",
    "    padded_seqs, seq_lens = [], []\n",
    "    for sentence in sentence_batch:\n",
    "        # Truncate first: an over-long row would otherwise make the batch ragged.\n",
    "        kept = list(sentence[:target_len])\n",
    "        seq_lens.append(len(kept))\n",
    "        padded_seqs.append(kept + [pad_int] * (target_len - len(kept)))\n",
    "    return padded_seqs, seq_lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [01:00<00:00,  3.36it/s, accuracy=0.0533, cost=117, rouge_2=0.00606] \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:02<00:00,  4.28it/s, accuracy=0.0246, cost=96.1, rouge_2=0]     \n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, avg loss: 141.292347, avg accuracy: 0.024805\n",
      "epoch: 0, avg loss test: 126.745460, avg accuracy test: 0.037240\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [00:56<00:00,  4.63it/s, accuracy=0.0798, cost=112, rouge_2=0.0116] \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:01<00:00, 10.43it/s, accuracy=0.0606, cost=124, rouge_2=0.00641]\n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, avg loss: 125.585923, avg accuracy: 0.056561\n",
      "epoch: 1, avg loss test: 124.742002, avg accuracy test: 0.056867\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [00:57<00:00,  4.44it/s, accuracy=0.0809, cost=129, rouge_2=0.011]  \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:01<00:00, 10.39it/s, accuracy=0.0547, cost=117, rouge_2=0.00725]\n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, avg loss: 121.680643, avg accuracy: 0.084453\n",
      "epoch: 2, avg loss test: 124.057311, avg accuracy test: 0.060970\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [00:56<00:00,  4.58it/s, accuracy=0.136, cost=120, rouge_2=0.0263]  \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:01<00:00, 10.48it/s, accuracy=0.0714, cost=110, rouge_2=0.00725]\n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, avg loss: 117.770769, avg accuracy: 0.114314\n",
      "epoch: 3, avg loss test: 124.797237, avg accuracy test: 0.062626\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [00:56<00:00,  4.18it/s, accuracy=0.165, cost=126, rouge_2=0.0502]  \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:01<00:00, 10.55it/s, accuracy=0.0597, cost=122, rouge_2=0.00725]\n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, avg loss: 113.329572, avg accuracy: 0.160522\n",
      "epoch: 4, avg loss test: 127.027611, avg accuracy test: 0.062630\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [00:56<00:00,  4.63it/s, accuracy=0.267, cost=102, rouge_2=0.0756] \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:01<00:00, 10.34it/s, accuracy=0.0301, cost=139, rouge_2=0]      \n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5, avg loss: 108.412371, avg accuracy: 0.220881\n",
      "epoch: 5, avg loss test: 128.815912, avg accuracy test: 0.060748\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [00:57<00:00,  4.47it/s, accuracy=0.293, cost=91, rouge_2=0.0966]  \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:01<00:00, 10.45it/s, accuracy=0.063, cost=110, rouge_2=0.00833] \n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 6, avg loss: 103.023471, avg accuracy: 0.297669\n",
      "epoch: 6, avg loss test: 133.076606, avg accuracy test: 0.052533\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [00:56<00:00,  4.29it/s, accuracy=0.343, cost=113, rouge_2=0.125]  \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:01<00:00, 10.37it/s, accuracy=0.0522, cost=144, rouge_2=0]      \n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 7, avg loss: 98.234655, avg accuracy: 0.380340\n",
      "epoch: 7, avg loss test: 139.086945, avg accuracy test: 0.057499\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [00:56<00:00,  4.66it/s, accuracy=0.469, cost=92.3, rouge_2=0.198] \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:01<00:00,  9.83it/s, accuracy=0.0597, cost=145, rouge_2=0]      \n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 8, avg loss: 94.417634, avg accuracy: 0.448188\n",
      "epoch: 8, avg loss test: 145.733629, avg accuracy test: 0.055748\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [00:56<00:00,  4.68it/s, accuracy=0.47, cost=96, rouge_2=0.177]   \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:01<00:00, 10.31it/s, accuracy=0.0336, cost=129, rouge_2=0]      "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 9, avg loss: 91.752299, avg accuracy: 0.499448\n",
      "epoch: 9, avg loss test: 152.279775, avg accuracy test: 0.060171\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "from sklearn.utils import shuffle\n",
    "import time\n",
    "\n",
    "def run_epoch(inputs, targets, desc, train = False):\n",
    "    '''Run one pass over (inputs, targets) in mini-batches of `batch_size`.\n",
    "\n",
    "    The optimiser step is executed only when `train` is true. Returns the mean of\n",
    "    the per-batch losses and the mean of the per-batch accuracies.\n",
    "    '''\n",
    "    fetches = [model.training_logits, model.accuracy, model.cost]\n",
    "    if train:\n",
    "        fetches.append(model.optimizer)\n",
    "    starts = range(0, len(inputs), batch_size)\n",
    "    loss_sum, accuracy_sum = 0, 0\n",
    "    pbar = tqdm(starts, desc = desc)\n",
    "    for k in pbar:\n",
    "        batch_x, _ = pad_sentence_batch(inputs[k: k + batch_size], PAD, maxlen = max_len)\n",
    "        batch_y, _ = pad_sentence_batch(targets[k: k + batch_size], PAD)\n",
    "        l, acc, loss = sess.run(fetches,\n",
    "                                feed_dict = {model.X: batch_x,\n",
    "                                             model.Y: batch_y})[:3]\n",
    "        loss_sum += loss\n",
    "        accuracy_sum += acc\n",
    "        r = rouge.rouge_n(np.argmax(l, axis = 2), batch_y)\n",
    "        pbar.set_postfix(cost = loss, accuracy = acc, rouge_2 = r)\n",
    "    # Average over the real number of batches: len(inputs) / batch_size is fractional\n",
    "    # whenever the last batch is short, which inflates the reported means.\n",
    "    n_batches = max(len(starts), 1)\n",
    "    return loss_sum / n_batches, accuracy_sum / n_batches\n",
    "\n",
    "for EPOCH in range(10):\n",
    "    train_X, train_Y = shuffle(train_X, train_Y)\n",
    "    test_X, test_Y = shuffle(test_X, test_Y)\n",
    "    total_loss, total_accuracy = run_epoch(train_X, train_Y, 'train minibatch loop', train = True)\n",
    "    total_loss_test, total_accuracy_test = run_epoch(test_X, test_Y, 'test minibatch loop')\n",
    "\n",
    "    print('epoch: %d, avg loss: %f, avg accuracy: %f'%(EPOCH, total_loss, total_accuracy))\n",
    "    print('epoch: %d, avg loss test: %f, avg accuracy test: %f'%(EPOCH, total_loss_test, total_accuracy_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-example batch for decoding, padded to the fixed encoder width.\n",
    "b_x, _ = pad_sentence_batch(sentence_batch = [test_X[0]], pad_int = PAD, maxlen = max_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 21)"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run beam search and keep the beams of the first (only) example in the batch.\n",
    "generated = sess.run(fetches = model.generate, feed_dict = {model.X: b_x})[0]\n",
    "generated.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GO india shastri is the rounder six comma six in people about come people returns shastri is six starting six teams\n",
      "GO india shastri is the rounder six comma six comma six in people returns shastri is six starting six teams shastri\n",
      "GO india shastri is the rounder six starting returns comma six in people returns shastri is six starting six teams shastri\n",
      "GO india shastri is the rounder six comma six in people about come people returns shastri is six starting six matured\n",
      "GO india shastri is the rounder six comma six in people about come people returns shastri is six starting six the\n"
     ]
    }
   ],
   "source": [
    "# One line per beam, with ids mapped back to words.\n",
    "for g in generated:\n",
    "    print(' '.join(rev_dictionary[i] for i in g))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'shastris and kumbles will come and go ravi shastri EOS'"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference headline for the test example decoded above.\n",
    "' '.join([rev_dictionary[i] for i in test_Y[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
